{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import os.path as osp\n",
    "import time\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "# os.chdir('../')\n",
    "from mmcv.cnn.bricks.transformer import MultiheadAttention\n",
    "# from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func\n",
    "# from my_projects.CMT.cmt.models.utils.flash_attention import FlashMHA\n",
    "from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttnFunction\n",
    "from mmengine.model import BaseModule, ModuleList\n",
    "from mmengine.model import xavier_init, uniform_init, constant_init"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DeformableAttention2MultiModality(BaseModule):\n",
    "    \n",
    "    def __init__(self, embed_dims=256, num_heads=8, num_points=4, dropout=0.1, im2col_step=64, batch_first=True):\n",
    "        super(DeformableAttention2MultiModality, self).__init__()\n",
    "        if embed_dims % num_heads != 0:\n",
    "            raise ValueError(f'embed_dims ({embed_dims}) must be divisible by num_heads ({num_heads})')\n",
    "        assert batch_first is True\n",
    "        self.embed_dims = embed_dims\n",
    "        self.num_heads = num_heads\n",
    "        self.num_points = num_points\n",
    "        self.sampling_offsets = nn.Linear(embed_dims, num_heads * num_points * 3)\n",
    "        self.cam_embedding = nn.Sequential(\n",
    "            nn.Linear(12, embed_dims // 2),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Linear(embed_dims // 2, embed_dims),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.LayerNorm(embed_dims)\n",
    "        )\n",
    "        self.pts_attention_weights = nn.Linear(embed_dims, num_heads * num_points)\n",
    "        self.img_attention_weights = nn.Linear(embed_dims, num_heads * num_points)\n",
    "        self.pts_proj = nn.Linear(embed_dims, embed_dims)\n",
    "        self.img_proj = nn.Linear(embed_dims, embed_dims)\n",
    "        self.output_proj = nn.Linear(2 * embed_dims, embed_dims)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.im2col_step = im2col_step\n",
    "        self.init_weights()\n",
    "\n",
    "    def init_weights(self):\n",
    "        uniform_init(self.sampling_offsets, -2.0, 2.0)\n",
    "        constant_init(self.pts_attention_weights, 0.0)\n",
    "        constant_init(self.img_attention_weights, 0.0)\n",
    "        xavier_init(self.pts_proj)\n",
    "        xavier_init(self.img_proj)\n",
    "        xavier_init(self.output_proj)\n",
    "\n",
    "    def forward(self, query, query_pos, pts_feat, pts_pos, img_feat, img_pos, \n",
    "                reference_points, pc_range, img_metas):\n",
    "        \"\"\"Forward function for `DeformableAttention2MultiModality`.\n",
    "        Args:\n",
    "            query, query_pos (Tensor): [bs, num_query, embed_dims]\n",
    "            pts_feat, pts_pos (Tensor): [bs, pts_l, pts_w, embed_dims]\n",
    "            img_feat, img_pos (Tensor): [bs * num_cam, img_h, img_w, embed_dims]\n",
    "            reference_points (Tensor): [bs, num_query, 3]\n",
    "            pc_range (Tensor): [6]\n",
    "            img_metas (dict): meta information, must contain 'lidar2img' 'pad_shape'\n",
    "        Returns:\n",
    "            Tensor: [bs, num_query, embed_dims]\n",
    "        \"\"\"\n",
    "        bs, num_query, embed_dims = query.shape\n",
    "        pts_l, pts_w = pts_feat.shape[1], pts_feat.shape[2]\n",
    "        img_h, img_w = img_feat.shape[1], img_feat.shape[2]\n",
    "        assert embed_dims == self.embed_dims\n",
    "        num_cam = img_feat.shape[0] // bs\n",
    "        assert num_cam == len(img_metas[0]['lidar2img'])\n",
    "        assert pts_feat.shape[0] == img_feat.shape[0] // num_cam == bs\n",
    "        assert query.shape == query_pos.shape\n",
    "        assert pts_feat.shape == pts_pos.shape\n",
    "        assert img_feat.shape == img_pos.shape\n",
    "\n",
    "        # project pts_feat, img_feat\n",
    "        pts_feat = self.pts_proj(pts_feat + pts_pos).view(bs, pts_l * pts_w, self.num_heads, -1)\n",
    "        img_feat = self.img_proj(img_feat + img_pos).view(bs * num_cam, img_h * img_w, self.num_heads, -1)\n",
    "\n",
    "        # get sampling offsets and attention\n",
    "        identity = query\n",
    "        query = query + query_pos\n",
    "        sampling_offsets = self.sampling_offsets(query).view(\n",
    "            bs, num_query, self.num_heads, self.num_points, 3)\n",
    "        pts_attention_weights = self.pts_attention_weights(query).view(\n",
    "            bs, num_query, self.num_heads, 1, self.num_points).softmax(dim=-1)\n",
    "        lidars2imgs = np.stack([meta['lidar2img'] for meta in img_metas])\n",
    "        lidars2imgs = torch.from_numpy(lidars2imgs).float().to(query.device)    # [bs, num_cam, 4, 4]\n",
    "        cam_embedding = self.cam_embedding(lidars2imgs[..., :3, :].flatten(-2)) # [bs, num_cam, embed_dims]\n",
    "        query_cam = query.unsqueeze(1) + cam_embedding.unsqueeze(2)             # [bs, num_cam, num_query, embed_dims]\n",
    "        img_attention_weights = self.img_attention_weights(query_cam).view(\n",
    "            bs * num_cam, num_query, self.num_heads, 1, self.num_points).softmax(dim=-1)\n",
    "\n",
    "        # get pts sampling points\n",
    "        reference_points = reference_points * (pc_range[3:6] - pc_range[0:3]) + pc_range[0:3]\n",
    "        sampling_points = reference_points.unsqueeze(2).unsqueeze(3) + sampling_offsets # [bs, num_query, num_heads, num_points, 3]\n",
    "        sampling_points = torch.cat([sampling_points, torch.ones_like(sampling_points[..., :1])], dim=-1)\n",
    "        pts_points = sampling_points[..., :2]\n",
    "        pts_points[..., 0] = pts_points[..., 0] / (pc_range[3] - pc_range[0])\n",
    "        pts_points[..., 1] = pts_points[..., 1] / (pc_range[4] - pc_range[1])\n",
    "        pts_points = pts_points.view(bs, num_query, self.num_heads, 1, self.num_points, 2).contiguous()\n",
    "\n",
    "        # get img sampling points\n",
    "        img_points = torch.matmul(lidars2imgs[:, :, None, None, None], sampling_points[:, None, ..., None]).squeeze(-1)\n",
    "        img_points = img_points[..., :2] / torch.clamp(img_points[..., 2:3], min=1e-5)\n",
    "        img_points[..., 0] = img_points[..., 0] / img_metas[0]['pad_shape'][1]\n",
    "        img_points[..., 1] = img_points[..., 1] / img_metas[0]['pad_shape'][0]\n",
    "        # img_point_mask = (img_points[..., 0] >= 0) & (img_points[..., 0] <= 1) & (img_points[..., 1] >= 0) & (img_points[..., 1] <= 1) & (img_points[..., 2] > 0)\n",
    "        img_points = img_points.view(bs*num_cam, num_query, self.num_heads, 1, self.num_points, 2).contiguous()\n",
    "        \n",
    "        # get pts, img features\n",
    "        out_pts = MultiScaleDeformableAttnFunction.apply(\n",
    "            pts_feat, torch.tensor([[pts_l, pts_w]]).to(query.device), torch.tensor([0]).to(query.device), pts_points, pts_attention_weights, self.im2col_step)\n",
    "        out_img = MultiScaleDeformableAttnFunction.apply(\n",
    "            img_feat, torch.tensor([[img_h, img_w]]).to(query.device), torch.tensor([0]).to(query.device), img_points, img_attention_weights, self.im2col_step)\n",
    "        \n",
    "        # get output\n",
    "        out_img, _ = out_img.view(bs, num_cam, num_query, embed_dims).transpose(1, 2).max(dim=2)\n",
    "        output = torch.cat([out_pts, out_img], dim=-1)\n",
    "        output = self.output_proj(output)\n",
    "        output = self.dropout(output) + identity\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "regular_attention = MultiheadAttention(embed_dims=256, num_heads=8, batch_first=True).cuda()\n",
    "deformable_attention = DeformableAttention2MultiModality(embed_dims=256, num_heads=8).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "lidar2img1 = np.array([[ 3.97734714e+02, -5.06729679e+02, -5.42572583e+01,\n",
    "            2.74828461e+01],\n",
    "          [ 2.46238289e+02, -3.78748864e+00, -5.36753266e+02,\n",
    "           -5.95638454e+01],\n",
    "          [ 9.91205872e-01, -1.70923051e-02, -1.31220294e-01,\n",
    "           -3.00785658e-02],\n",
    "          [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,\n",
    "            1.00000000e+00]])\n",
    "lidar2img2 = np.array([[ 3.97710728e+02, -5.06737747e+02, -5.42431648e+01,\n",
    "           -7.30179615e+01],\n",
    "          [ 2.46215876e+02, -3.78133911e+00, -5.36727283e+02,\n",
    "           -5.95612466e+01],\n",
    "          [ 9.91180349e-01, -1.70659016e-02, -1.31184252e-01,\n",
    "           -3.00764817e-02],\n",
    "          [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,\n",
    "            1.00000000e+00]])\n",
    "lidar2img = [lidar2img1, lidar2img2]\n",
    "img_metas = [{'lidar2img': lidar2img, 'pad_shape': (640, 800)}]\n",
    "pc_range = torch.tensor([-21.0, -21.0, -2.0, 21.0, 21.0, 6.0]).cuda()\n",
    "query = torch.randn(1, 500, 256).cuda()\n",
    "query_pos = torch.randn(1, 500, 256).cuda()\n",
    "pts_feat = torch.randn(1, 70, 70, 256).cuda()\n",
    "pts_pos = torch.randn(1, 70, 70, 256).cuda()\n",
    "img_feat = torch.randn(2, 40, 50, 256).cuda()\n",
    "img_pos = torch.randn(2, 40, 50, 256).cuda()\n",
    "reference_points = torch.randn(1, 500, 3).cuda()\n",
    "\n",
    "key = torch.randn(1, 70*70 + 2*40*50, 256).cuda()\n",
    "value = torch.randn(1, 70*70 + 2*40*50, 256).cuda()\n",
    "key_pos = torch.randn(1, 70*70 + 2*40*50, 256).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "regular_attention time: 4.763118028640747\n"
     ]
    }
   ],
   "source": [
    "time1 = time.time()\n",
    "for i in range(1000):\n",
    "    output1 = regular_attention(query, key, value, query_pos=query_pos, key_pos=key_pos)\n",
    "print('regular_attention time:', time.time() - time1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "deformable_attention time: 0.11110138893127441\n"
     ]
    }
   ],
   "source": [
    "time2 = time.time()\n",
    "for i in range(1000):\n",
    "    output2 = deformable_attention(query, query_pos, pts_feat, pts_pos, img_feat, img_pos, reference_points, pc_range, img_metas)\n",
    "print('deformable_attention time:', time.time() - time2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create an instance of the regular attention layer\n",
    "regular_attention = MultiheadAttention(embed_dims=256, num_heads=8).cuda()\n",
    "\n",
    "# Create an instance of the deformable attention layer\n",
    "deformable_attention = MultiScaleDeformableAttention(embed_dims=256, num_heads=8).cuda()\n",
    "\n",
    "# Create an instance of the flash attention layer\n",
    "flash_attention = FlashMHA(embed_dim=256, num_heads=8).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.3918044567108154\n"
     ]
    }
   ],
   "source": [
    "start = time.time()\n",
    "with torch.no_grad():\n",
    "    query = torch.randn(1024, 1, 256, device=torch.device('cuda'))\n",
    "    attn_mask = torch.zeros(1024, 1024, device=torch.device('cuda')).bool()\n",
    "    for _ in range(1000):\n",
    "        out = regular_attention(query, attn_mask=attn_mask)\n",
    "print(time.time() - start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9290800094604492\n"
     ]
    }
   ],
   "source": [
    "start = time.time()\n",
    "with torch.no_grad():\n",
    "    query = torch.randn(1024, 1, 256, device=torch.device('cuda'))\n",
    "    reference_points = torch.randn(16, 1024, 1, 2, device=torch.device('cuda'))\n",
    "    spatial_shapes = torch.tensor([[32, 32]], device=torch.device('cuda'))\n",
    "    level_start_index = torch.tensor([0], device=torch.device('cuda'))\n",
    "    for _ in range(1000):\n",
    "        out = deformable_attention(query, reference_points=reference_points, \n",
    "            spatial_shapes=spatial_shapes, level_start_index=level_start_index)\n",
    "print(time.time() - start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9855546951293945\n"
     ]
    }
   ],
   "source": [
    "start = time.time()\n",
    "with torch.no_grad():\n",
    "    query = torch.randn(1024, 1, 256, device=torch.device('cuda'))\n",
    "    for _ in range(1000):\n",
    "        out = flash_attention(query, query, query)\n",
    "print(time.time() - start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
